
Conversation

Contributor

@abhijeet-dhumal commented Sep 9, 2025

Description

RHOAIENG-32470
An example demonstrating Kubeflow Trainer V2 features and terminology
TRL-based supervised fine-tuning with PEFT LoRA on GPT-2 and the Alpaca dataset for instruction following

How Has This Been Tested?

Tested the notebook by running it end to end.

Merge criteria:

  • The commits are squashed in a cohesive manner and have meaningful messages.
  • Testing instructions have been added in the PR body (for PRs involving changes that are not immediately obvious).
  • The developer has manually tested the changes and verified that they work.

Summary by CodeRabbit

  • New Features

    • Added ready-to-run Kubernetes examples for a multi-step, 2-node PyTorch training runtime with shared workspace and job sequencing.
    • Added a shared ReadWriteMany PersistentVolumeClaim example.
    • Added distributed Fashion‑MNIST training script with checkpointing, resume support, and live progression tracking.
    • Added TRL (GPT‑2) training entrypoint with LoRA support, distributed execution, checkpointing/resume, and graceful shutdown handling.
  • Documentation

    • Added README and notebook demonstrating end-to-end TRL training, runtime setup, job submission, monitoring, and cleanup.


openshift-ci bot commented Sep 9, 2025

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:
Once this PR has been reviewed and has the lgtm label, please assign kryanbeane for approval. For more information see the Code Review Process.

The full list of commands accepted by this bot can be found here.

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment


coderabbitai bot commented Sep 9, 2025

Walkthrough

Adds Kubeflow Trainer V2 example assets: a ClusterTrainingRuntime manifest (torch-cuda-custom) and a shared PVC (workspace), two distributed training scripts (Fashion‑MNIST and TRL/LoRA GPT‑2) with checkpointing and JSON progression tracking, plus a TRL notebook and README documenting usage.

Changes

  • Manifests (runtime & PVC): examples/kft-v2/manifests/cluster_training_runtime.yaml, examples/kft-v2/manifests/shared_pvc.yaml
    Adds a ClusterTrainingRuntime resource torch-cuda-custom (trainer.kubeflow.org/v1alpha1) configuring a 2‑node PyTorch runtime with three replicatedJobs (dataset-initializer → model-initializer → node) and a shared workspace PVC (ReadWriteMany, 50Gi, storageClassName: nfs-csi, Filesystem).
  • Training Scripts: examples/kft-v2/scripts/mnist.py, examples/kft-v2/scripts/trl_training.py
    New entry points train_fashion_mnist() and trl_train() implementing distributed setup via PET env vars, progression tracking (JSON status), checkpoint discovery/resume, multi-node/multi-GPU DDP handling, and graceful checkpointing on signals; includes data/model loaders, callbacks, and training loops.
  • Notebook & Docs: examples/kft-v2/trl-gpt2-checkpointing.ipynb, examples/kft-v2/README.md
    Adds a TRL/GPT-2 checkpointing demo notebook showing runtime enumeration, TrainJob submission/monitoring, and cleanup; adds a README describing requirements, quick-starts, and feature notes for the examples.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant K8s as Kubernetes API
  participant CTR as ClusterTrainingRuntime
  participant DS as dataset-initializer
  participant MI as model-initializer
  participant TR as trainer node

  User->>K8s: kubectl apply PVC + ClusterTrainingRuntime
  K8s->>CTR: create runtime resource (torch-cuda-custom)
  CTR->>DS: start replicatedJob dataset-initializer
  DS-->>CTR: Completed
  CTR->>MI: start replicatedJob model-initializer (dependsOn dataset Complete)
  MI-->>CTR: Completed
  CTR->>TR: start replicatedJob node (dependsOn model Complete)
  TR-->>CTR: Completed
sequenceDiagram
  autonumber
  participant Proc as Trainer Process
  participant Dist as torch.distributed
  participant FS as Shared PVC (/workspace)
  participant PT as ProgressionTracker
  participant CK as Checkpoint Storage

  Proc->>Dist: read PET_* env, init process group (gloo/nccl)
  Proc->>FS: load dataset/model from initializers or remote
  Proc->>CK: discover latest checkpoint (resume if present)
  loop Training loop
    Proc->>PT: update_step (loss, lr, metrics)
    alt epoch boundary
      Proc->>CK: save checkpoint (epoch-*.pth / save_pretrained)
      Proc->>PT: update_epoch (avg loss, accuracy, checkpoint meta)
    end
  end
  Proc->>PT: write_status(Completed)
  Proc->>Dist: destroy_process_group
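For orientation, here is a minimal Python sketch of the trainer-process flow in the second diagram, assuming the PET_* environment variables, CHECKPOINT_URI, and TRAINJOB_PROGRESSION_FILE_PATH names described in this PR; it is simplified and not the actual script code:

```python
import json
import os
from pathlib import Path

import torch
import torch.distributed as dist


def run_trainer(train_one_epoch, total_epochs=3):
    # Distributed setup from the env injected by the runtime (torchrun-style env:// init).
    nnodes = int(os.getenv("PET_NNODES", "1"))
    node_rank = int(os.getenv("PET_NODE_RANK", "0"))
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    if nnodes > 1:
        dist.init_process_group(backend=backend)

    # Shared PVC locations for checkpoints and the progression JSON.
    ckpt_dir = Path(os.getenv("CHECKPOINT_URI", "/workspace/checkpoints"))
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    status_path = Path(
        os.getenv("TRAINJOB_PROGRESSION_FILE_PATH", str(ckpt_dir / "training_progression.json"))
    )

    # Resume after the newest epoch-*.pth checkpoint, if any.
    start_epoch = len(sorted(ckpt_dir.glob("epoch-*.pth"))) + 1

    for epoch in range(start_epoch, total_epochs + 1):
        avg_loss = train_one_epoch(epoch)
        if node_rank == 0:
            torch.save({"epoch": epoch}, ckpt_dir / f"epoch-{epoch}.pth")
            status_path.write_text(json.dumps({"status": "Running", "epoch": epoch, "loss": avg_loss}))

    if node_rank == 0:
        status_path.write_text(json.dumps({"status": "Completed"}))
    if nnodes > 1:
        dist.destroy_process_group()
```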

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I hop through mounts and mounted drives,
I stitch up weights and count my strides.
Checkpoints hum and JSON glows,
Nodes align in ordered rows.
A rabbit saves when signals sing—trained paws, ready spring! 🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 34.21%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title succinctly captures the main change, adding a Kubeflow Trainer V2 demonstration, and aligns with the examples, scripts, and manifests introduced in this pull request.
✨ Finishing touches
  • 📝 Generate docstrings
  • 🧪 Generate unit tests (beta)
    • Create PR with unit tests
    • Post copyable unit tests in a comment


coderabbitai bot left a comment

Actionable comments posted: 17

🧹 Nitpick comments (18)
examples/kft-v2/manifests/shared_pvc.yaml (1)

12-12: Add missing trailing newline.

Fix YAML lint error (new-line-at-end-of-file).

-  volumeMode: Filesystem
+  volumeMode: Filesystem
+
examples/kft-v2/manifests/cluster_training_runtime.yaml (2)

116-119: Persist progression file on the shared volume.

Writing to /tmp is ephemeral; put the progression JSON on /workspace so it persists and is visible to controllers.

-                        - name: TRAINJOB_PROGRESSION_FILE_PATH
-                          value: /tmp/training_progression.json
+                        - name: TRAINJOB_PROGRESSION_FILE_PATH
+                          value: /workspace/training_progression.json

106-115: NCCL env on CPU-only runs is unnecessary.

If this runtime is CPU-only, consider dropping NCCL_* envs to simplify. The script already selects gloo on CPU.

examples/kft-v2/scripts/mnist.py (5)

58-61: Fix type hints to use Optional[...] per PEP 484.

Adjust signatures to satisfy Ruff RUF013.

-            loss: float = None,
-            learning_rate: float = None,
-            checkpoint_dir: str = None,
+            loss: Optional[float] = None,
+            learning_rate: Optional[float] = None,
+            checkpoint_dir: Optional[str] = None,
@@
-        def update_epoch(self, epoch: int, checkpoint_dir: str = None, **metrics):
+        def update_epoch(self, epoch: int, checkpoint_dir: Optional[str] = None, **metrics):

Also applies to: 127-127


1-1: Remove shebang or mark file executable.

To silence EXE001, either remove the shebang (common for importable modules) or set execute bit. Removing is simplest.

-#!/usr/bin/env python3

45-46: Default progression file should live on a mounted volume, not /tmp.

Align default with the manifest so status persists and is observable.

-            self.status_file_path = status_file_path or os.getenv(
-                "TRAINJOB_PROGRESSION_FILE_PATH", "/tmp/training_progression.json"
+            self.status_file_path = status_file_path or os.getenv(
+                "TRAINJOB_PROGRESSION_FILE_PATH", "/workspace/training_progression.json"

471-471: Avoid modifying grads when reporting grad norm.

clip_grad_norm_ changes gradients; compute norm without clipping or move clipping before optimizer.step if intended.

Add a helper and use it here:

def compute_grad_norm(parameters, norm_type=2.0):
    params = [p for p in parameters if p.grad is not None]
    if not params:
        return 0.0
    if norm_type == float("inf"):
        return max(p.grad.detach().abs().max().item() for p in params)
    total = 0.0
    for p in params:
        param_norm = p.grad.detach().data.norm(norm_type)
        total += param_norm.item() ** norm_type
    return total ** (1.0 / norm_type)
-                        grad_norm=f"{torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0):.4f}"
+                        grad_norm=f"{compute_grad_norm(model.parameters()):.4f}"

224-225: Narrow overly broad exception handlers.

Catching bare Exception hides actionable errors. Scope to expected exceptions (e.g., RuntimeError for DDP init; OSError/ValueError for file ops).

Also applies to: 285-286, 340-341
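For illustration, here is the kind of narrowing this suggests, as hedged standalone helpers rather than the exact mnist.py lines:

```python
import json

import torch.distributed as dist


def init_distributed_or_fallback(backend="gloo"):
    """Return True if the process group came up, False to continue single-process."""
    try:
        dist.init_process_group(backend=backend)
        return True
    except (RuntimeError, ValueError) as e:  # rendezvous/config failures
        print(f"Distributed init failed, continuing single-process: {e}")
        return False


def write_status(path, status):
    """Best-effort write of the progression JSON, scoped to file/serialization errors."""
    try:
        with open(path, "w") as f:
            json.dump(status, f)
    except (OSError, TypeError) as e:
        print(f"Could not write progression file: {e}")
```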

examples/kft-v2/scripts/trl_training.py (7)

556-558: Remove unused variable.

checkpoint_interval is never used.

-        checkpoint_interval = os.getenv('CHECKPOINT_INTERVAL', '30s')

606-606: Drop redundant import.

os is already imported at Line 4.

-    import os

272-274: Don’t swallow exceptions silently.

Replace bare except/pass with minimal logging.

-            except Exception as e:
-                pass
+            except Exception as e:
+                print(f"Failed to write simple progress: {e}")

65-65: Consistent step indexing (+1).

mnist.ProgressionTracker uses 1-based steps; mirror for consistency.

-                self.current_step = (epoch - 1) * self.steps_per_epoch + step
+                self.current_step = (epoch - 1) * self.steps_per_epoch + step + 1

51-52: Progress file placement and durability.

Defaulting to /tmp risks loss on pod restart. Prefer a PVC path like CHECKPOINT_URI or a dedicated PROGRESSION_URI on RWX storage.

Option: set TRAINJOB_PROGRESSION_FILE_PATH to f"{CHECKPOINT_URI}/training_progression.json" and ensure directory exists before writing.

Also applies to: 181-185, 270-271, 774-782
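A short sketch of that option (the default path here is an assumption):

```python
import os
from pathlib import Path

# Derive the progression file from the checkpoint location on the RWX PVC
# instead of /tmp, and make sure the directory exists before the first write.
checkpoint_uri = Path(os.getenv("CHECKPOINT_URI", "/workspace/checkpoints"))
progression_path = Path(
    os.getenv("TRAINJOB_PROGRESSION_FILE_PATH", str(checkpoint_uri / "training_progression.json"))
)
progression_path.parent.mkdir(parents=True, exist_ok=True)
```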


663-669: Minor: remove f-prefix on constant strings.

-                print(f"Applied DDP static graph fix for distributed training")
+                print("Applied DDP static graph fix for distributed training")
@@
-                    print(f"PEFT model detected, DDP parameters properly configured")
+                    print("PEFT model detected, DDP parameters properly configured")

37-187: Deduplicate utilities.

ProgressionTracker and setup_distributed duplicate mnist.py versions. Factor into a shared helper (e.g., examples/kft-v2/scripts/common.py) for reuse.

I can extract and wire a common module if helpful.

Also applies to: 424-454
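One possible shape for that shared module; examples/kft-v2/scripts/common.py is only a suggested path and the body is a simplified sketch, not the current implementation:

```python
import os

import torch
import torch.distributed as dist


def setup_distributed():
    """Initialise torch.distributed from the torchrun-provided environment."""
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    global_rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    device, backend = ("cuda", "nccl") if torch.cuda.is_available() else ("cpu", "gloo")
    if world_size > 1 and not dist.is_initialized():
        dist.init_process_group(backend=backend)
    return local_rank, global_rank, world_size, device
```

Both scripts (and the notebook copy) would then import ProgressionTracker and setup_distributed from this module instead of redefining them.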

examples/kft-v2/trl-gpt2-checkpointing.ipynb (3)

624-626: Progress file path on PVC instead of /tmp.

Use a RWX path so UIs can read it across restarts.

Example:

-    "        checkpoint_dir = Path(os.getenv('CHECKPOINT_URI', '/workspace/checkpoints'))\n",
+    "        checkpoint_dir = Path(os.getenv('CHECKPOINT_URI', '/workspace/checkpoints'))\n",
+    "        os.environ.setdefault('TRAINJOB_PROGRESSION_FILE_PATH', str(checkpoint_dir / 'training_progression.json'))\n",

And in env:

-    "    \"TRAINJOB_PROGRESSION_FILE_PATH\": \"/tmp/training_progression.json\",\n",
+    "    \"TRAINJOB_PROGRESSION_FILE_PATH\": \"/workspace/checkpoints/training_progression.json\",\n",

Also applies to: 916-917, 833-833


768-818: Reuse shared utilities to avoid drift.

ProgressionTracker and setup_distributed duplicate the script and mnist.py. Extract to a common module.

I can factor these into examples/kft-v2/scripts/common.py and update imports.


955-955: Minor: remove f-strings without placeholders.

-    "                print(f\"Applied DDP static graph fix for distributed training\")\n",
+    "                print(\"Applied DDP static graph fix for distributed training\")\n",
@@
-    "                    print(f\"PEFT model detected, DDP parameters properly configured\")\n",
+    "                    print(\"PEFT model detected, DDP parameters properly configured\")\n",
@@
-    "    \"print(f\\\"Trainjob submitted!!\\\")\"\n",
+    "    \"print(\\\"Trainjob submitted!!\\\")\"\n",

Also applies to: 963-963, 735-741

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0235a63 and 09204ae.

⛔ Files ignored due to path filters (4)
  • examples/kft-v2/dist/kubeflow_trainer_api-2.0.0-py3-none-any.whl is excluded by !**/dist/**
  • examples/kft-v2/docs/jobs.png is excluded by !**/*.png
  • examples/kft-v2/docs/trainjob_pods.png is excluded by !**/*.png
  • examples/kft-v2/docs/trainjobs_jobsets.png is excluded by !**/*.png
📒 Files selected for processing (5)
  • examples/kft-v2/manifests/cluster_training_runtime.yaml (1 hunks)
  • examples/kft-v2/manifests/shared_pvc.yaml (1 hunks)
  • examples/kft-v2/scripts/mnist.py (1 hunks)
  • examples/kft-v2/scripts/trl_training.py (1 hunks)
  • examples/kft-v2/trl-gpt2-checkpointing.ipynb (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/kft-v2/scripts/trl_training.py (1)
examples/kft-v2/scripts/mnist.py (7)
  • ProgressionTracker (21-225)
  • update_step (54-125)
  • get_checkpoint_number (87-97)
  • get_checkpoint_number (155-164)
  • write_status (181-225)
  • update_epoch (127-179)
  • setup_distributed (246-290)
examples/kft-v2/scripts/mnist.py (2)
examples/kft-v2/scripts/trl_training.py (7)
  • ProgressionTracker (37-187)
  • update_step (60-105)
  • get_checkpoint_number (80-84)
  • get_checkpoint_number (128-132)
  • write_status (143-187)
  • update_epoch (107-141)
  • setup_distributed (424-453)
tests/kfto/resources/mnist.py (1)
  • train (136-140)
🪛 YAMLlint (1.37.1)
examples/kft-v2/manifests/shared_pvc.yaml

[error] 12-12: no new line character at the end of file

(new-line-at-end-of-file)

🪛 Ruff (0.12.2)
examples/kft-v2/scripts/trl_training.py

51-51: Probable insecure usage of temporary file or directory: "/tmp/training_progression.json"

(S108)


60-60: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


60-60: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


60-60: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


107-107: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


169-169: Multiple statements on one line (colon)

(E701)


170-170: Multiple statements on one line (colon)

(E701)


171-171: Multiple statements on one line (colon)

(E701)


172-172: Multiple statements on one line (colon)

(E701)


186-186: Do not catch blind exception: Exception

(BLE001)


191-191: Multiple statements on one line (colon)

(E701)


192-192: Multiple statements on one line (colon)

(E701)


241-241: Do not catch blind exception: Exception

(BLE001)


272-273: try-except-pass detected, consider logging the exception

(S110)


272-272: Do not catch blind exception: Exception

(BLE001)


272-272: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)


283-283: Do not catch blind exception: Exception

(BLE001)


291-291: Do not catch blind exception: Exception

(BLE001)


295-295: Unused method argument: signum

(ARG002)


295-295: Unused method argument: frame

(ARG002)


302-302: Unused method argument: args

(ARG002)


302-302: Unused method argument: state

(ARG002)


302-302: Unused method argument: control

(ARG002)


302-302: Unused method argument: kwargs

(ARG002)


318-318: Do not catch blind exception: Exception

(BLE001)


321-321: Unused method argument: kwargs

(ARG002)


370-370: Unused method argument: args

(ARG002)


370-370: Unused method argument: control

(ARG002)


370-370: Unused method argument: kwargs

(ARG002)


383-383: Unused method argument: args

(ARG002)


383-383: Unused method argument: control

(ARG002)


383-383: Unused method argument: kwargs

(ARG002)


401-401: Unused method argument: control

(ARG002)


401-401: Unused method argument: kwargs

(ARG002)


417-417: Do not catch blind exception: Exception

(BLE001)


450-450: Do not catch blind exception: Exception

(BLE001)


470-470: Consider moving this statement to an else block

(TRY300)


471-471: Do not catch blind exception: Exception

(BLE001)


508-508: Consider moving this statement to an else block

(TRY300)


510-510: Do not catch blind exception: Exception

(BLE001)


556-556: Local variable checkpoint_interval is assigned to but never used

Remove assignment to unused variable checkpoint_interval

(F841)


606-606: Redefinition of unused os from line 4

Remove definition: os

(F811)


614-614: Do not catch blind exception: Exception

(BLE001)


663-663: f-string without any placeholders

Remove extraneous f prefix

(F541)


668-668: f-string without any placeholders

Remove extraneous f prefix

(F541)


669-669: Do not catch blind exception: Exception

(BLE001)


692-692: Do not catch blind exception: Exception

(BLE001)


730-730: f-string without any placeholders

Remove extraneous f prefix

(F541)


738-738: Do not catch blind exception: Exception

(BLE001)


754-754: Do not catch blind exception: Exception

(BLE001)


770-770: Use raise without specifying exception name

Remove exception name

(TRY201)


787-787: Do not catch blind exception: Exception

(BLE001)

examples/kft-v2/scripts/mnist.py

1-1: Shebang is present but file is not executable

(EXE001)


45-45: Probable insecure usage of temporary file or directory: "/tmp/training_progression.json"

(S108)


58-58: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


59-59: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


60-60: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


127-127: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


224-224: Do not catch blind exception: Exception

(BLE001)


285-285: Do not catch blind exception: Exception

(BLE001)


340-340: Do not catch blind exception: Exception

(BLE001)

examples/kft-v2/trl-gpt2-checkpointing.ipynb

61-61: Probable insecure usage of temporary file or directory: "/tmp/training_progression.json"

(S108)


70-70: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


70-70: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


70-70: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


118-118: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


181-181: Multiple statements on one line (colon)

(E701)


182-182: Multiple statements on one line (colon)

(E701)


183-183: Multiple statements on one line (colon)

(E701)


184-184: Multiple statements on one line (colon)

(E701)


198-198: Do not catch blind exception: Exception

(BLE001)


204-204: Multiple statements on one line (colon)

(E701)


205-205: Multiple statements on one line (colon)

(E701)


253-253: Do not catch blind exception: Exception

(BLE001)


284-285: try-except-pass detected, consider logging the exception

(S110)


284-284: Do not catch blind exception: Exception

(BLE001)


284-284: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)


295-295: Do not catch blind exception: Exception

(BLE001)


303-303: Do not catch blind exception: Exception

(BLE001)


307-307: Unused method argument: signum

(ARG002)


307-307: Unused method argument: frame

(ARG002)


314-314: Unused method argument: args

(ARG002)


314-314: Unused method argument: state

(ARG002)


314-314: Unused method argument: control

(ARG002)


314-314: Unused method argument: kwargs

(ARG002)


330-330: Do not catch blind exception: Exception

(BLE001)


333-333: Unused method argument: kwargs

(ARG002)


382-382: Unused method argument: args

(ARG002)


382-382: Unused method argument: control

(ARG002)


382-382: Unused method argument: kwargs

(ARG002)


395-395: Unused method argument: args

(ARG002)


395-395: Unused method argument: control

(ARG002)


395-395: Unused method argument: kwargs

(ARG002)


413-413: Unused method argument: control

(ARG002)


413-413: Unused method argument: kwargs

(ARG002)


429-429: Do not catch blind exception: Exception

(BLE001)


462-462: Do not catch blind exception: Exception

(BLE001)


482-482: Consider moving this statement to an else block

(TRY300)


483-483: Do not catch blind exception: Exception

(BLE001)


520-520: Consider moving this statement to an else block

(TRY300)


522-522: Do not catch blind exception: Exception

(BLE001)


568-568: Local variable checkpoint_interval is assigned to but never used

Remove assignment to unused variable checkpoint_interval

(F841)


616-616: Redefinition of unused os from line 14

Remove definition: os

(F811)


624-624: Do not catch blind exception: Exception

(BLE001)


672-672: f-string without any placeholders

Remove extraneous f prefix

(F541)


677-677: f-string without any placeholders

Remove extraneous f prefix

(F541)


678-678: Do not catch blind exception: Exception

(BLE001)


701-701: Do not catch blind exception: Exception

(BLE001)


739-739: f-string without any placeholders

Remove extraneous f prefix

(F541)


747-747: Do not catch blind exception: Exception

(BLE001)


763-763: Do not catch blind exception: Exception

(BLE001)


779-779: Use raise without specifying exception name

Remove exception name

(TRY201)


796-796: Do not catch blind exception: Exception

(BLE001)


807-807: Probable insecure usage of temporary file or directory: "/tmp/lib:$PYTHONPATH"

(S108)


833-833: Probable insecure usage of temporary file or directory: "/tmp/training_progression.json"

(S108)


836-836: Dictionary key literal "PYTHONUNBUFFERED" repeated

Remove repeated key literal "PYTHONUNBUFFERED"

(F601)


842-842: Dictionary key literal "NCCL_DEBUG" repeated

Remove repeated key literal "NCCL_DEBUG"

(F601)


848-848: Dictionary key literal "TORCH_DISTRIBUTED_DEBUG" repeated

Remove repeated key literal "TORCH_DISTRIBUTED_DEBUG"

(F601)


894-894: Possible hardcoded password assigned to: "token"

(S105)


935-935: Do not catch blind exception: Exception

(BLE001)


948-948: Do not catch blind exception: Exception

(BLE001)


955-955: f-string without any placeholders

Remove extraneous f prefix

(F541)


963-963: f-string without any placeholders

Remove extraneous f prefix

(F541)


969-969: Do not catch blind exception: Exception

(BLE001)

🔇 Additional comments (4)
examples/kft-v2/manifests/shared_pvc.yaml (1)

1-12: PVC looks correct for RWX NFS usage.

RWX + Filesystem + nfs-csi is appropriate for shared checkpoints.

examples/kft-v2/manifests/cluster_training_runtime.yaml (1)

6-10: CRD semantics and command injection verified

  • mlPolicy.numNodes defines the ML-level node count (e.g. 2 pods), and the trainer controller reconciles or overrides the template’s replicatedJobs[*].replicas to match.
  • Trainer.command and trainer.args are merged into the node container’s command/args at JobSet creation (e.g. torchrun flags are injected) so the base image will not idle.
examples/kft-v2/scripts/trl_training.py (1)

646-654: Verify TRL API: parameter name.

TRL SFTTrainer typically uses tokenizer=, not processing_class=. Confirm your TRL version supports processing_class or switch to tokenizer=tokenizer.

Would you like me to check the current TRL docs and adjust accordingly?
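A quick, hedged way to check against the installed TRL version before switching keywords:

```python
import inspect

import trl
from trl import SFTTrainer

params = inspect.signature(SFTTrainer.__init__).parameters
print(f"trl=={trl.__version__}")
print("accepts processing_class:", "processing_class" in params)
print("accepts tokenizer:", "tokenizer" in params)
```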

examples/kft-v2/trl-gpt2-checkpointing.ipynb (1)

718-726: No changes needed for processing_class. Confirmed that TRL’s SFTTrainer now accepts processing_class=tokenizer, so the existing code is correct.

Member
@szaher left a comment

Thanks @abhijeet-dhumal! Great work.

I would patch the Kubeflow SDK for now, or install it from a branch, instead of shipping the wheel directly.

Since you already have the solution to expose the labels and annotations, maybe you can contribute it here and install directly from main.

@abhijeet-dhumal requested a review from szaher October 6, 2025 16:30
@abhijeet-dhumal marked this pull request as ready for review October 6, 2025 16:31
coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (11)
examples/kft-v2/scripts/trl_training.py (4)

10-11: Critical security issue flagged in previous review remains unaddressed.

The code still imports numpy.core.multiarray._reconstruct, registers it as a safe global (line 35), and patches torch.load to default weights_only=False (lines 189-195). This approach widens the RCE attack surface during model deserialization.

Based on previous review feedback, change to safer defaults.

Also applies to: 35-35, 189-195
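For reference, a minimal sketch of the safer default being asked for, assuming PyTorch 2.4+ (which provides torch.serialization.add_safe_globals); the checkpoint path is illustrative:

```python
import numpy.core.multiarray
import torch

# Allow-list only the specific global the checkpoint needs and keep the
# weights_only=True default, rather than patching torch.load globally.
torch.serialization.add_safe_globals([numpy.core.multiarray._reconstruct])

checkpoint = torch.load("/workspace/checkpoints/epoch-1.pth", weights_only=True)
```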


442-451: Past issue remains: hard-coded Gloo backend loses GPU performance.

The code still hard-codes backend='gloo' (line 445), which was flagged in previous reviews. For GPU training, NCCL offers significantly better performance. See mnist.py lines 267-268 for the correct pattern:

device, backend = ("cuda", "nccl") if torch.cuda.is_available() else ("cpu", "gloo")

466-468: Past issue remains: negative split size when dataset < 20 samples.

Lines 466-468 can produce negative train_size when len(full_dataset) < 20. This was flagged in previous reviews.

Apply the suggested fix:

-                    train_size = min(100, len(full_dataset) - 20)
-                    train_dataset = full_dataset.select(range(train_size))
-                    test_dataset = full_dataset.select(range(train_size, min(train_size + 20, len(full_dataset))))
+                    total = len(full_dataset)
+                    test_sz = min(20, max(0, total))
+                    train_sz = min(100, max(0, total - test_sz))
+                    train_dataset = full_dataset.select(range(train_sz))
+                    test_dataset = full_dataset.select(range(train_sz, min(train_sz + test_sz, total)))

577-578: Past issues remain: incorrect precision flags and hard-coded backend.

Two problems flagged in previous reviews:

  1. Lines 577-578: fp16 on CPU will crash; bf16 needs GPU support check
  2. Line 596: ddp_backend='gloo' hard-coded, losing GPU performance

Apply the precision fix:

-            'bf16': torch.cuda.is_available(),  # Only use bf16 if CUDA is available
-            'fp16': not torch.cuda.is_available(),  # Use fp16 for CPU training
+            'bf16': torch.cuda.is_available() and bool(getattr(torch.cuda, 'is_bf16_supported', lambda: False)()),
+            'fp16': False,  # Disable fp16 unless explicitly needed

And the backend fix:

-            'ddp_backend': 'gloo',  # Use gloo backend for better LoRA compatibility
+            'ddp_backend': 'nccl' if torch.cuda.is_available() else 'gloo',

Also applies to: 596-596

examples/kft-v2/scripts/mnist.py (4)

292-544: Critical: Nested function never called, training is a no-op.

The outer train_fashion_mnist() (line 2) defines an inner train_fashion_mnist() (line 292) but never calls it. When main() invokes the outer function, no training executes. This was flagged in previous reviews.

Rename the inner function and call it:

-    def train_fashion_mnist():
+    def _run_training():
         # Setup distributed training
         local_rank, global_rank, world_size, device_type = setup_distributed()
         ...
         # Finally clean up PyTorch distributed
         if world_size > 1:
             dist.destroy_process_group()
+    
+    # Run the actual training
+    _run_training()

176-179: Past issue remains: off-by-one in epoch display.

Line 176 adds 1 to epoch, but the docstring (line 65) and usage (line 516) indicate epoch is already the absolute 1-based epoch number, causing the display to show epoch+1.

Apply this fix:

-            epoch_num = epoch + 1
             total_epochs = self.total_epochs
-            message = f"Completed epoch {epoch_num}/{total_epochs}"
+            message = f"Completed epoch {epoch}/{total_epochs}"

304-307: Past issue remains: DDP missing device_ids for multi-GPU nodes.

Line 305 wraps the model without specifying device_ids and output_device, which can cause warnings on multi-GPU nodes. This was flagged in previous reviews.

Apply this fix:

         # Create model and wrap with DDP only if distributed
         net = Net().to(device)
         if world_size > 1:
-            model = nn.parallel.DistributedDataParallel(net)
+            if device.type == "cuda":
+                model = nn.parallel.DistributedDataParallel(
+                    net, device_ids=[local_rank], output_device=local_rank
+                )
+            else:
+                model = nn.parallel.DistributedDataParallel(net)
         else:
             model = net

467-467: Past issue remains: dist.get_world_size() raises when not initialized.

Line 467 calls dist.get_world_size() which will raise in single-process mode. Use the world_size variable from setup_distributed() instead, as flagged in previous reviews.

Apply this fix:

-                        world_size=dist.get_world_size(),
+                        world_size=world_size,
examples/kft-v2/trl-gpt2-checkpointing.ipynb (3)

123-169: Past issue remains: duplicate environment variable keys.

Lines 123, 155 duplicate PYTHONUNBUFFERED; lines 124, 161 duplicate NCCL_DEBUG; lines 125, 167 duplicate TORCH_DISTRIBUTED_DEBUG. Later values overwrite earlier ones, as flagged in previous reviews.

Remove the duplicates:

     "PYTHONUNBUFFERED": "1",
     "NCCL_DEBUG": "INFO",
     "TORCH_DISTRIBUTED_DEBUG": "INFO",
     "PYTHONPATH": "/tmp/lib:$PYTHONPATH",
     ...
     # Cache directories
-    "PYTHONUNBUFFERED": "1",
     "TRANSFORMERS_CACHE": "/workspace/cache/transformers",
     ...
     # Distributed training debug
-    "NCCL_DEBUG": "INFO",
     "NCCL_DEBUG_SUBSYS": "ALL",
     ...
-    "TORCH_DISTRIBUTED_DEBUG": "INFO",

183-193: Past issue remains: pip flags in packages list and missing commas.

Lines 191-192 include pip flags as list items, and line 190 is missing a comma, causing string concatenation. This was flagged in previous reviews.

Fix by removing pip flags and adding comma:

     packages_to_install=[
         "transformers[torch]",
         "trl", 
         "peft", 
         "datasets", 
         "accelerate",
         "torch",
-        "numpy"
-        " --target=/tmp/lib"
-        " --verbose"
+        "numpy",
     ],

If you need to pass additional pip arguments, check if Kubeflow SDK supports a dedicated parameter for pip flags.


244-252: Past security issues remain: hardcoded credentials and disabled TLS.

Lines 244-245 use placeholder strings for credentials that users might accidentally commit, and line 252 disables SSL verification. These were flagged in previous reviews.

Use environment variables:

+import os
-api_server = "<api-server-url>"
-token = "<auth-token>"
+api_server = os.getenv("KFP_API_SERVER", "")
+token = os.getenv("KFP_API_TOKEN", "")

 configuration = client.Configuration()
 configuration.host = api_server
 configuration.api_key = {"authorization": f"Bearer {token}"}

-# Un-comment if your cluster API server uses a self-signed certificate or an un-trusted CA
-configuration.verify_ssl = False
+# Only disable SSL verification if explicitly needed
+if os.getenv("KFP_VERIFY_SSL", "true").lower() == "false":
+    configuration.verify_ssl = False
🧹 Nitpick comments (5)
examples/kft-v2/scripts/trl_training.py (3)

37-188: Extract ProgressionTracker to eliminate duplication.

ProgressionTracker is nearly identical in trl_training.py (lines 37-188) and mnist.py (lines 21-225), with only minor differences in handling global_step and ETA formatting. Extract to a shared utility module to improve maintainability.


556-556: Remove unused variable.

Line 556: checkpoint_interval is assigned but never used.

Apply this diff:

         checkpoint_dir = Path(os.getenv('CHECKPOINT_URI', '/workspace/checkpoints'))
         checkpoint_enabled = os.getenv('CHECKPOINT_ENABLED', 'false').lower() == 'true'
-        checkpoint_interval = os.getenv('CHECKPOINT_INTERVAL', '30s')
         max_checkpoints = int(os.getenv('CHECKPOINT_MAX_RETAIN', '5'))

606-606: Remove duplicate import.

Line 606 re-imports os, which was already imported at line 4.

Apply this diff:

     """Training function."""
-
-    import os
examples/kft-v2/scripts/mnist.py (1)

1-1: Make file executable if shebang is intended.

Line 1 has a shebang but the file is not executable. Either make it executable with chmod +x or remove the shebang if not needed.

examples/kft-v2/trl-gpt2-checkpointing.ipynb (1)

171-171: Add import clarification for better UX.

Line 171 imports trl_train from trl_training.py. Consider adding a cell comment explaining that users need the scripts/trl_training.py file in their Python path, or show how to add it:

import sys
sys.path.append('./scripts')
from trl_training import trl_train
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 09204ae and c36f120.

⛔ Files ignored due to path filters (10)
  • examples/kft-v2/docs/01.png is excluded by !**/*.png
  • examples/kft-v2/docs/02.png is excluded by !**/*.png
  • examples/kft-v2/docs/03.png is excluded by !**/*.png
  • examples/kft-v2/docs/04.png is excluded by !**/*.png
  • examples/kft-v2/docs/05.png is excluded by !**/*.png
  • examples/kft-v2/docs/06.png is excluded by !**/*.png
  • examples/kft-v2/docs/07.png is excluded by !**/*.png
  • examples/kft-v2/docs/jobs.png is excluded by !**/*.png
  • examples/kft-v2/docs/trainjob_pods.png is excluded by !**/*.png
  • examples/kft-v2/docs/trainjobs_jobsets.png is excluded by !**/*.png
📒 Files selected for processing (6)
  • examples/kft-v2/README.md (1 hunks)
  • examples/kft-v2/manifests/cluster_training_runtime.yaml (1 hunks)
  • examples/kft-v2/manifests/shared_pvc.yaml (1 hunks)
  • examples/kft-v2/scripts/mnist.py (1 hunks)
  • examples/kft-v2/scripts/trl_training.py (1 hunks)
  • examples/kft-v2/trl-gpt2-checkpointing.ipynb (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/kft-v2/manifests/shared_pvc.yaml
  • examples/kft-v2/manifests/cluster_training_runtime.yaml
🧰 Additional context used
🧬 Code graph analysis (2)
examples/kft-v2/scripts/mnist.py (2)
examples/kft-v2/scripts/trl_training.py (7)
  • ProgressionTracker (37-187)
  • update_step (60-105)
  • get_checkpoint_number (80-84)
  • get_checkpoint_number (128-132)
  • write_status (143-187)
  • update_epoch (107-141)
  • setup_distributed (424-453)
tests/kfto/resources/mnist.py (1)
  • train (136-140)
examples/kft-v2/scripts/trl_training.py (1)
examples/kft-v2/scripts/mnist.py (7)
  • ProgressionTracker (21-225)
  • update_step (54-125)
  • get_checkpoint_number (87-97)
  • get_checkpoint_number (155-164)
  • write_status (181-225)
  • update_epoch (127-179)
  • setup_distributed (246-290)
🪛 markdownlint-cli2 (0.18.1)
examples/kft-v2/README.md

69-69: Images should have alternate text (alt text)

(MD045, no-alt-text)


73-73: Images should have alternate text (alt text)

(MD045, no-alt-text)


77-77: Images should have alternate text (alt text)

(MD045, no-alt-text)


81-81: Unordered list indentation
Expected: 2; Actual: 4

(MD007, ul-indent)


83-83: Unordered list indentation
Expected: 2; Actual: 4

(MD007, ul-indent)


85-85: Images should have alternate text (alt text)

(MD045, no-alt-text)


87-87: Images should have alternate text (alt text)

(MD045, no-alt-text)


95-95: Unordered list indentation
Expected: 2; Actual: 4

(MD007, ul-indent)


99-99: Images should have alternate text (alt text)

(MD045, no-alt-text)


144-144: Images should have alternate text (alt text)

(MD045, no-alt-text)

🪛 Ruff (0.13.3)
examples/kft-v2/scripts/mnist.py

1-1: Shebang is present but file is not executable

(EXE001)


45-45: Probable insecure usage of temporary file or directory: "/tmp/training_progression.json"

(S108)


58-58: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


59-59: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


60-60: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


127-127: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


224-224: Do not catch blind exception: Exception

(BLE001)


285-285: Do not catch blind exception: Exception

(BLE001)


294-294: Unpacked variable global_rank is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


340-340: Do not catch blind exception: Exception

(BLE001)

examples/kft-v2/scripts/trl_training.py

51-51: Probable insecure usage of temporary file or directory: "/tmp/training_progression.json"

(S108)


60-60: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


60-60: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


60-60: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


107-107: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


169-169: Multiple statements on one line (colon)

(E701)


170-170: Multiple statements on one line (colon)

(E701)


171-171: Multiple statements on one line (colon)

(E701)


172-172: Multiple statements on one line (colon)

(E701)


186-186: Do not catch blind exception: Exception

(BLE001)


191-191: Multiple statements on one line (colon)

(E701)


192-192: Multiple statements on one line (colon)

(E701)


241-241: Do not catch blind exception: Exception

(BLE001)


272-273: try-except-pass detected, consider logging the exception

(S110)


272-272: Do not catch blind exception: Exception

(BLE001)


272-272: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)


283-283: Do not catch blind exception: Exception

(BLE001)


291-291: Do not catch blind exception: Exception

(BLE001)


295-295: Unused method argument: signum

(ARG002)


295-295: Unused method argument: frame

(ARG002)


302-302: Unused method argument: args

(ARG002)


302-302: Unused method argument: state

(ARG002)


302-302: Unused method argument: control

(ARG002)


302-302: Unused method argument: kwargs

(ARG002)


318-318: Do not catch blind exception: Exception

(BLE001)


321-321: Unused method argument: kwargs

(ARG002)


370-370: Unused method argument: args

(ARG002)


370-370: Unused method argument: control

(ARG002)


370-370: Unused method argument: kwargs

(ARG002)


383-383: Unused method argument: args

(ARG002)


383-383: Unused method argument: control

(ARG002)


383-383: Unused method argument: kwargs

(ARG002)


401-401: Unused method argument: control

(ARG002)


401-401: Unused method argument: kwargs

(ARG002)


417-417: Do not catch blind exception: Exception

(BLE001)


450-450: Do not catch blind exception: Exception

(BLE001)


470-470: Consider moving this statement to an else block

(TRY300)


471-471: Do not catch blind exception: Exception

(BLE001)


508-508: Consider moving this statement to an else block

(TRY300)


510-510: Do not catch blind exception: Exception

(BLE001)


556-556: Local variable checkpoint_interval is assigned to but never used

Remove assignment to unused variable checkpoint_interval

(F841)


606-606: Redefinition of unused os from line 4

Remove definition: os

(F811)


608-608: Unpacked variable local_rank is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


614-614: Do not catch blind exception: Exception

(BLE001)


627-627: Unpacked variable script_args is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


663-663: f-string without any placeholders

Remove extraneous f prefix

(F541)


668-668: f-string without any placeholders

Remove extraneous f prefix

(F541)


669-669: Do not catch blind exception: Exception

(BLE001)


692-692: Do not catch blind exception: Exception

(BLE001)


730-730: f-string without any placeholders

Remove extraneous f prefix

(F541)


738-738: Do not catch blind exception: Exception

(BLE001)


754-754: Do not catch blind exception: Exception

(BLE001)


770-770: Use raise without specifying exception name

Remove exception name

(TRY201)


787-787: Do not catch blind exception: Exception

(BLE001)

examples/kft-v2/trl-gpt2-checkpointing.ipynb

12-12: Probable insecure usage of temporary file or directory: "/tmp/lib:$PYTHONPATH"

(S108)


38-38: Probable insecure usage of temporary file or directory: "/tmp/training_progression.json"

(S108)


41-41: Dictionary key literal "PYTHONUNBUFFERED" repeated

Remove repeated key literal "PYTHONUNBUFFERED"

(F601)


47-47: Dictionary key literal "NCCL_DEBUG" repeated

Remove repeated key literal "NCCL_DEBUG"

(F601)


53-53: Dictionary key literal "TORCH_DISTRIBUTED_DEBUG" repeated

Remove repeated key literal "TORCH_DISTRIBUTED_DEBUG"

(F601)


102-102: Possible hardcoded password assigned to: "token"

(S105)


143-143: Do not catch blind exception: Exception

(BLE001)


156-156: Do not catch blind exception: Exception

(BLE001)


163-163: f-string without any placeholders

Remove extraneous f prefix

(F541)


171-171: f-string without any placeholders

Remove extraneous f prefix

(F541)


177-177: Do not catch blind exception: Exception

(BLE001)

🔇 Additional comments (6)
examples/kft-v2/scripts/trl_training.py (3)

275-285: LGTM: Robust distributed SIGTERM handling.

The callback properly initializes a distributed tensor for coordinating SIGTERM signals across ranks and gracefully falls back to local handling if distributed is unavailable.
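For readers, an illustrative sketch of that pattern (not the exact callback code): each rank keeps a local flag, flips it on SIGTERM, and an all_reduce(MAX) makes every rank agree to stop and checkpoint together.

```python
import signal

import torch
import torch.distributed as dist


class SigtermCoordinator:
    def __init__(self, device="cpu"):  # use a CUDA tensor when the backend is NCCL
        self.flag = torch.zeros(1, device=device)
        signal.signal(signal.SIGTERM, self._on_sigterm)

    def _on_sigterm(self, signum, frame):
        self.flag.fill_(1.0)

    def should_stop(self):
        if dist.is_available() and dist.is_initialized():
            # Propagate the flag so all ranks observe the signal together.
            dist.all_reduce(self.flag, op=dist.ReduceOp.MAX)
        return bool(self.flag.item())
```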


641-670: Good: DDP static graph fix for LoRA compatibility.

The code applies _set_static_graph() to address DDP parameter marking issues with LoRA adapters, which is appropriate for this use case.


757-772: Good: Robust checkpoint retry logic.

The training loop properly attempts to resume from checkpoint and falls back to training from scratch if checkpoint loading fails. The error handling is appropriate.

examples/kft-v2/README.md (1)

182-189: LGTM: Reference links and documentation structure.

The reference section provides helpful links to official documentation for Kubeflow, TRL, PEFT, PyTorch, and OpenShift AI.

examples/kft-v2/trl-gpt2-checkpointing.ipynb (2)

335-354: LGTM: Robust log retrieval with error handling.

The cell properly handles potential errors when retrieving logs and provides a helpful message if logs aren't available yet.


397-426: LGTM: Cleanup function with status verification.

The cleanup properly retrieves and displays final job status before deletion, helping users verify the training completed successfully.


* Access the OpenShift AI dashboard, for example from the top navigation bar menu:

![](./docs/01.png)

⚠️ Potential issue | 🟡 Minor

Add descriptive alt text to images for accessibility.

Lines 69, 73, 77, 85, 87, 99, and 144 have images without alt text, which impacts accessibility for screen reader users.

Example fixes:

-![](./docs/01.png)
+![OpenShift AI dashboard navigation](./docs/01.png)

-![](./docs/02.png)
+![Data Science Projects page](./docs/02.png)

-![](./docs/07.png)
+![Training job execution example](./docs/07.png)

Also applies to: 73-73, 77-77, 85-85, 87-87, 99-99, 144-144

🧰 Tools
🪛 markdownlint-cli2 (0.18.1)

69-69: Images should have alternate text (alt text)

(MD045, no-alt-text)

🤖 Prompt for AI Agents
In examples/kft-v2/README.md around lines 69, 73, 77, 85, 87, 99, and 144 there
are image references using ![](path) with empty alt text; update each image tag
to include concise, descriptive alt text that summarizes the image content or
purpose (e.g., ![Diagram showing X] or ![Screenshot of Y]) so screen readers can
convey meaning; keep alt text short and informative, and ensure any
decorative-only images use alt="" intentionally while meaningful images get
descriptive text.

Comment on lines +109 to +142
```python
from scripts.mnist import train_fashion_mnist

# Configure training parameters
config = {
"epochs": 10,
"batch_size": 64,
"learning_rate": 0.001,
"checkpoint_dir": "/mnt/shared/checkpoints"
}

# Start training
train_fashion_mnist(config)
```

### **Example 2: TRL GPT-2 Fine-tuning**

Run the TRL training example:

```python
from scripts.trl_training import trl_train

# Configure TRL parameters
config = {
"model_name": "gpt2",
"dataset_name": "alpaca",
"lora_r": 16,
"lora_alpha": 32,
"max_seq_length": 512
}

# Start TRL training
trl_train(config)
```

⚠️ Potential issue | 🟠 Major

Incorrect API usage in quick start examples.

Both example code blocks show passing a config dict to the training functions, but:

  • train_fashion_mnist() (mnist.py:2) takes no parameters
  • trl_train() (trl_training.py:1) takes no parameters

Both read configuration from environment variables. Update the examples to show setting environment variables instead:

# Example 1: Fashion-MNIST Training
import os
os.environ['NUM_EPOCHS'] = '10'
os.environ['BATCH_SIZE'] = '64'
os.environ['LEARNING_RATE'] = '0.001'
os.environ['CHECKPOINT_DIR'] = '/mnt/shared/checkpoints'

from scripts.mnist import train_fashion_mnist
train_fashion_mnist()
# Example 2: TRL GPT-2 Fine-tuning
import os
os.environ['MODEL_NAME'] = 'gpt2'
os.environ['DATASET_NAME'] = 'tatsu-lab/alpaca'
os.environ['LORA_R'] = '16'
os.environ['LORA_ALPHA'] = '32'

from scripts.trl_training import trl_train
trl_train()
🤖 Prompt for AI Agents
In examples/kft-v2/README.md around lines 109 to 142 the quick-start examples
incorrectly pass a config dict to train_fashion_mnist() and trl_train(), but
both functions take no parameters and read settings from environment variables;
update the README examples to set the appropriate environment variables
(NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE, CHECKPOINT_DIR for MNIST; MODEL_NAME,
DATASET_NAME, LORA_R, LORA_ALPHA, etc. for TRL) via os.environ before importing
and calling the functions, and remove the config dict/arguments so the examples
reflect the actual API usage.

@@ -0,0 +1,141 @@
apiVersion: trainer.kubeflow.org/v1alpha1
kind: ClusterTrainingRuntime
Contributor

We should aim to use one of the pre-installed ClusterTrainingRuntimes.

Contributor Author

@astefanutti Yes, that was the plan, but at the moment the SDK doesn't allow providing volume mount specs for a TrainJob, so I had to explicitly add all the needed configs, env variables, and volume mounts in the ClusterTrainingRuntime itself.
Do you mean I should add all the needed config to the default torch-cuda-251 runtime and then use it as a reference for the client.train method while creating a TrainJob?

Contributor
@astefanutti Oct 7, 2025

@astefanutti Yes, that was the plan, but at the moment the SDK doesn't allow providing volume mount specs for a TrainJob, so I had to explicitly add all the needed configs, env variables, and volume mounts in the ClusterTrainingRuntime itself.

@abhijeet-dhumal right, that was my understanding.

Do you mean I should add all the needed config to the default torch-cuda-251 runtime and then use it as a reference for the client.train method while creating a TrainJob?

No, I mean we should try to fill the gaps and see how we can improve the SDK's flexibility. Could that be one of the options we are adding to the SDK?

Contributor Author

Yes, once we have this options fix merged in the SDK (kubeflow/sdk#91), I think we will be able to provide volume mounts as well as other configurations via PodSpecOverrides.

Contributor Author

I'm thinking of adjusting this PR again promptly and will definitely share the results soon.
This demo is fully adjusted to the latest Kubeflow SDK version 0.1.0, so are we good to keep the ClusterTrainingRuntime separate for now? I will update the PVC creation flow to use the Kubernetes API.

Contributor

This demo is fully adjusted to the latest Kubeflow SDK version 0.1.0, so are we good to keep the ClusterTrainingRuntime separate for now?

I'd be inclined to keep that PR open until we close this gap.

Comment on lines +129 to +136
" \"LEARNING_RATE\": \"5e-5\",\n",
" \"BATCH_SIZE\": \"1\",\n",
" \"MAX_EPOCHS\": \"3\",\n",
" \"WARMUP_STEPS\": \"5\",\n",
" \"EVAL_STEPS\": \"3\",\n",
" \"SAVE_STEPS\": \"2\",\n",
" \"LOGGING_STEPS\": \"2\",\n",
" \"GRADIENT_ACCUMULATION_STEPS\": \"2\",\n",
Contributor

It seems some of those environment variables would be better as training function arguments.

Contributor Author

Yeah, I guess I will use the func_args parameter of CustomTrainer to provide the function args instead of providing these as env vars.
Thanks a lot for pointing this out, on it 🙌
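For context, a hedged sketch of what that could look like; func, func_args, packages_to_install, and client.train are mentioned in this PR or the SDK discussion, while the runtime lookup, argument names, and values below are assumptions to verify against the SDK version in use:

```python
from kubeflow.trainer import CustomTrainer, TrainerClient

from scripts.trl_training import trl_train

client = TrainerClient()

job_name = client.train(
    runtime=client.get_runtime("torch-cuda-251"),  # assumed helper for a pre-installed runtime
    trainer=CustomTrainer(
        func=trl_train,
        func_args={  # hypothetical mapping of the values previously passed as env vars
            "learning_rate": 5e-5,
            "batch_size": 1,
            "max_epochs": 3,
        },
        num_nodes=2,
        packages_to_install=["transformers[torch]", "trl", "peft", "datasets", "accelerate"],
    ),
)
print(f"Submitted TrainJob: {job_name}")
```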

" # Uncomment for GPU training:\n",
" # \"nvidia.com/gpu\": \"1\",\n",
" },\n",
" packages_to_install=[\n",
Contributor

Aren't those packages already installed in the runtime image?

@@ -0,0 +1,12 @@
apiVersion: v1
Contributor

Could the Kubernetes Python SDK be used in the notebook to create that PVC?

Contributor Author

Yeah that will work too, on it!
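A sketch of what that could look like with the Kubernetes Python client; the namespace and kubeconfig loading are assumptions, and the PVC values mirror manifests/shared_pvc.yaml:

```python
from kubernetes import client, config

config.load_kube_config()  # or config.load_incluster_config() from inside a workbench pod

pvc_manifest = {
    "apiVersion": "v1",
    "kind": "PersistentVolumeClaim",
    "metadata": {"name": "workspace"},
    "spec": {
        "accessModes": ["ReadWriteMany"],
        "storageClassName": "nfs-csi",
        "resources": {"requests": {"storage": "50Gi"}},
        "volumeMode": "Filesystem",
    },
}

client.CoreV1Api().create_namespaced_persistent_volume_claim(
    namespace="my-data-science-project",  # hypothetical namespace
    body=pvc_manifest,
)
```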

Contributor Author

Then the ClusterTrainingRuntime can also be created, I guess using the Kubernetes custom_resource API; is that ok?
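And the ClusterTrainingRuntime could be created the same way through the custom objects API; the CRD plural below is an assumption worth checking against the installed CRD:

```python
import yaml
from kubernetes import client, config

config.load_kube_config()

with open("manifests/cluster_training_runtime.yaml") as f:
    runtime_manifest = yaml.safe_load(f)

client.CustomObjectsApi().create_cluster_custom_object(
    group="trainer.kubeflow.org",
    version="v1alpha1",
    plural="clustertrainingruntimes",  # assumed plural for the ClusterTrainingRuntime CRD
    body=runtime_manifest,
)
```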

Comment on lines +48 to +60
Create a shared persistent volume for checkpoint storage:

```bash
oc apply -f manifests/shared_pvc.yaml
```

### **3. Cluster Training Runtime Setup**

Apply the cluster training runtime configuration:

```bash
oc apply -f manifests/cluster_training_runtime.yaml
```
Contributor

We should aim at avoiding any extra oc commands.

Contributor Author

Then the ClusterTrainingRuntime can also be created, I guess using the Kubernetes custom_resource API; is that ok?

Contributor

Ideally the examples would use the pre-installed ClusterTrainingRuntimes and users would not have to create one for each example.


### **Cluster Requirements**
- **OpenShift Cluster**: With OpenShift AI (RHOAI) 2.17+ installed
- **Required Components**: `dashboard`, `trainingoperator`, and `workbenches` enabled
Contributor

The trainer v2 component should probably be trainer.

Contributor Author

Yeah, I will update it accordingly, thanks!
Thinking, is it ok to put it that way, as we don't have this utility available yet 🤔

Contributor

Will this example be available before v2 is in RHOAI?

Contributor

It would likely confuse people to have this example available using the v1 component so it may be better to hold this PR and wait until v2 is in RHOAI, WDYT?

Contributor Author

@Fiona-Waters @astefanutti Yeah, that makes sense, I will keep it as a draft for now!
There are some upcoming changes in the Kubeflow SDK that will further simplify the overall workflow here,
especially the TrainJob Options implementation, including the PodSpecOverrides capability, which will allow mounting volumes without needing to customise the default ClusterTrainingRuntimes 👍
So it will in turn help reduce the oc dependencies expected from the user ✅

@abhijeet-dhumal marked this pull request as draft October 8, 2025 09:13